from torch.utils.data import DataLoader
from torch import nn
import torch

from model.LLSTM import LLSTM
from getData import getData
import time


def train(model, train_loader, test_loader, train_size, test_size, device):
    """Train `model` with SGD/cross-entropy and evaluate once per epoch.

    Args:
        model: network whose forward is invoked as ``model(imgs, device)``
            (LSTM-specific calling convention; other models use ``model(imgs)``
            as noted in the comments below).
        train_loader: DataLoader yielding ``(imgs, targets)`` training batches.
        test_loader: DataLoader yielding ``(imgs, targets)`` test batches.
        train_size: number of training samples (kept for API compatibility;
            not used in the computation).
        test_size: number of test samples; used to normalise test accuracy.
        device: torch.device the model already lives on.

    Returns:
        None. Progress (loss/accuracy) is reported via print().
    """
    # Loss function
    loss_fn = nn.CrossEntropyLoss()
    loss_fn = loss_fn.to(device)

    # Optimizer
    learning_rate = 1e-2
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # Number of training batches processed so far (for logging only)
    total_train_step = 0
    # Number of training epochs
    epoch = 10

    start_time = time.time()
    for i in range(epoch):
        print("------第{}轮训练开始------".format(i))

        # ---- training phase ----
        # FIX: set train mode once per epoch instead of once per batch.
        model.train()
        for data in train_loader:
            imgs, targets = data

            # The LSTM model expects (batch, seq_len=6, features=69) input;
            # other models consume the flat batch directly — remove this
            # reshape when using them.
            imgs = imgs.reshape(-1, 6, 69)

            imgs = imgs.to(device)
            targets = targets.to(device)

            # LSTM-specific forward ↓
            outputs = model(imgs, device)
            # Forward for other models ↓
            # outputs = model(imgs)

            loss = loss_fn(outputs, targets)

            # Optimization step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_step += 1
            corrects = (torch.max(outputs, 1)[1].view(targets.size()) == targets).sum()
            # FIX: divide by the actual batch size rather than the hard-coded
            # constant 10, which was wrong for the final (smaller) batch and
            # for any loader with a different batch_size.
            accuracy = float(corrects) / targets.size(0) * 100.0
            print("训练次数：{}, Loss:{:.4f}".format(total_train_step, loss.item()))
            print("训练次数：{}, Accuracy:{:.4f}".format(total_train_step, accuracy))

        # ---- evaluation phase ----
        # FIX: the test loop was nested inside the per-batch training loop,
        # so the whole test set was evaluated after EVERY training batch.
        # It now runs once per epoch, as intended.
        model.eval()
        total_test_loss = 0
        total_accuracy = 0
        with torch.no_grad():
            for data in test_loader:
                imgs, targets = data

                # Same LSTM-specific reshape as in the training phase.
                imgs = imgs.reshape(-1, 6, 69)

                imgs = imgs.to(device)
                targets = targets.to(device)
                # Forward for other models ↓
                # outputs = model(imgs)
                # LSTM-specific forward ↓
                outputs = model(imgs, device)
                loss = loss_fn(outputs, targets)
                total_test_loss += loss.item()
                # FIX: .item() so we accumulate a Python number, not a tensor.
                total_accuracy += (outputs.argmax(1) == targets).sum().item()

        print("整体测试集上的loss:{}".format(total_test_loss))
        print("整体测试集上的正确率:{}".format(total_accuracy / test_size))

    end_time = time.time()
    print("整体训练时间:{}".format(str(end_time-start_time)))


def select_model_and_train(model, path_info, path_label):
    """Place the model on GPU when available, build the loaders, and train.

    Args:
        model: the network to train (moved to the selected device here).
        path_info: path to the feature spreadsheet consumed by getData.
        path_label: path to the label spreadsheet consumed by getData.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    print(device)

    # Split the raw spreadsheets into train/test datasets via the project helper.
    train_set, test_set = getData.get_train_and_test_dataset(path_info, path_label)
    n_train = len(train_set)
    n_test = len(test_set)
    train_loader = DataLoader(train_set, batch_size=10)
    test_loader = DataLoader(test_set)

    train(model, train_loader, test_loader, n_train, n_test, device)


if __name__ == "__main__":
    # Spreadsheets holding the training features and labels.
    path_info = "../dataset/train_info.xlsx"
    path_label = "../dataset/train_label.xlsx"

    # LLSTM(69, 128, 2, 2) — presumably (input_size, hidden_size, num_layers,
    # num_classes); confirm against model/LLSTM.py.
    lstm_model = LLSTM(69, 128, 2, 2)
    select_model_and_train(lstm_model, path_info, path_label)



